#!/usr/bin/env jq
# compile-config-2.02-learning_inference -- Adds processes for learning weights, performing inference with the grounded factor graph, and loading the results into the database
##

include "constants";
include "sql";

# skip adding learning/inference processes unless one for grounding is there
# (when the grounding process is missing, the input config is passed through unchanged)
if .deepdive_.execution.processes | has("process/grounding/combine_factorgraph") | not then . else

# keep the whole .deepdive_ object at hand so the process definitions below can
# read the sampler settings and the schema variables from it
.deepdive_ as $deepdive
# merge the following process definitions into the existing processes object
| .deepdive_.execution.processes += {

    # learning weights and doing inference (since we had to load the graph anyway)
    #
    # The cmd string below is a shell script assembled by jq.  For every shard id
    # in DEEPDIVE_FACTORGRAPH_SHARDS (default: 0 .. sampler.partitions - 1) it:
    #   1. links weights.to_learn to the shard's own weights file (shard 0) or to
    #      the previous shard's weights.learned,
    #   2. invokes run-sampler with the configured sampler command and arguments,
    #      appending -i 0 when there is more than one partition,
    #   3. creates the inference result weights table, loads
    #      inference_result.out.weights.text into it (spaces turned into tabs),
    #      analyzes it, and renames the file with a .learned suffix,
    #   4. only when there is more than one partition: removes any
    #      inference_result.out.text, dumps the weights back into weights.learned/
    #      and, unless the shard id is below partitions - 1, points
    #      ../learned_weights at that directory.
    # After the loop it creates the weights mapping view.
    #
    # The single- vs. multi-partition variants are chosen here by jq when the
    # config is compiled; the generated script contains only the selected variant.
    "process/model/learning": {
        dependencies_: ["model/factorgraph"],
        output_: ["model/weights", "data/model/weights"],
        style: "cmd_extractor",
        cmd: "
        export DEEPDIVE_LOAD_FORMAT=tsv
        : ${DEEPDIVE_NUM_PROCESSES:=$(nproc --ignore=1)}
        : ${DEEPDIVE_RESULTS_DIR:=\"$DEEPDIVE_APP\"/run/model/results}
        : ${DEEPDIVE_FACTORGRAPH_DIR:=\"$DEEPDIVE_APP\"/run/model/factorgraph}
        export DEEPDIVE_NUM_PROCESSES

        # defaults to handling all shards, which can be overridden to be a subset
        : ${DEEPDIVE_FACTORGRAPH_SHARDS:=$(seq 0 \($deepdive.sampler.partitions - 1))}

        # allow extra runtime args for sampler
        : ${DEEPDIVE_SAMPLER_ARGS:=}

        # run inference engine for learning and inference on each shard
        for pid in $DEEPDIVE_FACTORGRAPH_SHARDS; do
            results_dir=\"$DEEPDIVE_RESULTS_DIR\"/$pid
            mkdir -p \"$results_dir\"
            fg_dir=\"$DEEPDIVE_FACTORGRAPH_DIR\"/$pid
            [[ -d \"$fg_dir\" ]] || error \"$fg_dir: No factorgraph found\"
            cd \"$fg_dir\"

            # decide which set of weights to start learning from
            if [[ $pid -eq 0 ]]; then
                # start from cold weights for the first shard
                weights_hot=weights
            else
                # use hot weights learned from previous shard
                weights_hot=../$(($pid - 1))/weights.learned
            fi
            [[ $weights_hot -ef weights.to_learn ]] || ln -sfn $weights_hot weights.to_learn

            # if there are multiple shards, we defer inference (-i 0)
            run-sampler \\
                    weights=weights.to_learn \\
                    results_dir=\"$results_dir\" \\
                -- \\
                    \($deepdive.sampler.sampler_cmd | @sh) \\
                    \($deepdive.sampler.sampler_args // "") \\
                    \(if $deepdive.sampler.partitions > 1 then "-i 0" else "" end) \\
                    $DEEPDIVE_SAMPLER_ARGS \\
                    #

            \(if $deepdive.sampler.partitions > 1 then "
            # make sure no stale inference result is laying around unless inference is done right after learning above
            rm -f \"$results_dir\"/inference_result.out.text
            " else "" end)

            # TODO maybe this database interaction for the weights after
            # learning each shard is unnecessary if we have the sampler checkpoint
            # them directly, and we could also keep max two copies

            # load weights to database
            deepdive create table \(deepdiveInferenceResultWeightsTable | @sh) \\
                wid:BIGINT:'PRIMARY KEY' \\
                weight:'DOUBLE PRECISION' \\
                #
            cat \"$results_dir\"/inference_result.out.weights.text |
            tr \(" "|@sh) \("\\t"|@sh) |
            deepdive load \(deepdiveInferenceResultWeightsTable | @sh) /dev/stdin
            deepdive db analyze \(deepdiveInferenceResultWeightsTable | @sh)
            mv -f \"$results_dir\"/inference_result.out.weights.text \"$results_dir\"/inference_result.out.weights.text.learned

            \(if $deepdive.sampler.partitions > 1 then "
                # dump weights back out for next-shard learning or inference step
                mkdir -p weights.learned
                deepdive compute execute \\
                    input_sql=\(
                        # one row per global weight: id, fixed flag as 1 or 0, and value
                        { SELECT:
                            [ { table: "w", column: "wid" }
                            , { expr: "CASE WHEN isfixed THEN 1 ELSE 0 END" }
                            # learned weight if present, else the initial value, else 0
                            , { expr: "COALESCE(r.weight, w.initvalue, 0)" }
                            ]
                        , FROM: [ { alias: "w", table: deepdiveGlobalWeightsTable } ]
                        # left outer join keeps global weights that have no learned row
                        , JOIN: { LEFT_OUTER: { alias: "r", table: deepdiveInferenceResultWeightsTable }
                                , ON: { eq: [ { table: "r", column: "wid" }
                                            , { table: "w", column: "wid" }
                                            ] }
                                }
                        } | asSql | asPrettySqlArg) \\
                    command=\("
                        sampler-dw text2bin weight /dev/stdin /dev/stdout /dev/null |
                        pbzip2 >weights.learned/weights.part-${DEEPDIVE_CURRENT_PROCESS_INDEX}.bin.bz2
                    " | @sh) \\
                    output_relation=

                # expose the last shard's weights for the inference phase
                [[ $pid -lt \($deepdive.sampler.partitions - 1) ]] ||
                    ln -sfn $pid/weights.learned ../learned_weights

            " else "" end)

        done

        deepdive create view \(deepdiveInferenceResultWeightsMappingView | @sh) as \(
            # all global weight columns plus the learned weight; the inner join
            # restricts the view to weights that have a learned row, and ordering
            # is requested by absolute learned weight, descending
            { SELECT:
                [ { expr: "\"w\".*" }
                , { table: "r", column: "weight" }
                ]
            , FROM: { alias: "w", table: deepdiveGlobalWeightsTable }
            , JOIN: { INNER: { alias: "r", table: deepdiveInferenceResultWeightsTable }
                    , ON: { eq: [ { table: "r", column: "wid" }
                                , { table: "w", column: "wid" }
                                ] } }
            , ORDER_BY: { expr: "ABS(\"r\".\"weight\")", order: "DESC" }
            } | asSql | asPrettySqlArg)
        "
    },

    # performing inference
    #
    # The cmd string below is a shell script that walks the shard ids in
    # DEEPDIVE_FACTORGRAPH_SHARDS (default: 0 .. sampler.partitions - 1).  A shard
    # whose results directory already holds inference_result.out.text is skipped;
    # otherwise weights.to_infer is linked to ../learned_weights and run-sampler is
    # invoked with the configured sampler command and arguments plus -l 0.
    #
    # NOTE(review): ../learned_weights is only created by process/model/learning
    # when sampler.partitions > 1.  With a single partition this step appears to
    # rely on inference_result.out.text already being present so that the shard is
    # skipped -- confirm what happens when that file is missing.
    "process/model/inference": {
        dependencies_: ["model/factorgraph", "model/weights"],
        output_: ["model/probabilities"],
        style: "cmd_extractor",
        cmd: "
        : ${DEEPDIVE_NUM_PROCESSES:=$(nproc --ignore=1)}
        : ${DEEPDIVE_RESULTS_DIR:=\"$DEEPDIVE_APP\"/run/model/results}
        : ${DEEPDIVE_FACTORGRAPH_DIR:=\"$DEEPDIVE_APP\"/run/model/factorgraph}
        export DEEPDIVE_NUM_PROCESSES

        # defaults to handling all shards, which can be overridden to be a subset
        : ${DEEPDIVE_FACTORGRAPH_SHARDS:=$(seq 0 \($deepdive.sampler.partitions - 1))}

        # allow extra runtime args for sampler
        : ${DEEPDIVE_SAMPLER_ARGS:=}

        # need to run inference if learning step was deferred because there are multiple shards
        # or because different weights are used
        for pid in $DEEPDIVE_FACTORGRAPH_SHARDS; do
            results_dir=\"$DEEPDIVE_RESULTS_DIR\"/$pid
            mkdir -p \"$results_dir\"
            fg_dir=\"$DEEPDIVE_FACTORGRAPH_DIR\"/$pid
            [[ -d \"$fg_dir\" ]] || error \"$fg_dir: No factorgraph found\"
            cd \"$fg_dir\"

            # skip if inference was already done
            ! [[ -e \"$results_dir\"/inference_result.out.text ]] || continue

            # run sampler for inference with given weights without learning
            ln -sfn ../learned_weights weights.to_infer

            run-sampler \\
                    weights=weights.to_infer \\
                    results_dir=\"$results_dir\" \\
                -- \\
                    \($deepdive.sampler.sampler_cmd | @sh) \\
                    \($deepdive.sampler.sampler_args // "") \\
                    -l 0 \\
                    $DEEPDIVE_SAMPLER_ARGS \\
                    #
        done
        "
    },

    # Creates the table that receives the inference results, with columns
    # vid BIGINT, cid BIGINT and prb DOUBLE PRECISION.  The table name comes from
    # deepdiveInferenceResultVariablesTable and is shell-quoted with @sh.
    # This entry declares no dependencies_, input_ or output_ of its own; it is
    # referenced as a dependency by process/model/load_inference_results.
    "process/model/load_probabilities_prepare": {
        style: "cmd_extractor",
        cmd: "
            # prepare loading probabilities from inference results to the database
            deepdive create table \(deepdiveInferenceResultVariablesTable | @sh) \\
                vid:BIGINT \\
                cid:BIGINT \\
                prb:'DOUBLE PRECISION' \\
                #
        "
    },


    # Loads the per-shard inference results into the inference result variables
    # table.  For each shard id in DEEPDIVE_FACTORGRAPH_SHARDS (default:
    # 0 .. sampler.partitions - 1) the generated script changes into that shard's
    # results directory, pipes inference_result.out.text through
    # deepdive env restore_partitioned_vids (with pid set in its environment) and
    # feeds the output to deepdive load with DEEPDIVE_LOAD_FORMAT=tsv.
    "process/model/load_inference_results": {
        dependencies_: ["process/model/load_probabilities_prepare"],
        input_: ["model/probabilities"],
        style: "cmd_extractor",
        cmd: "
            : ${DEEPDIVE_RESULTS_DIR:=\"$DEEPDIVE_APP\"/run/model/results}

            # defaults to handling all shards, which can be overridden to be a subset
            : ${DEEPDIVE_FACTORGRAPH_SHARDS:=$(seq 0 \($deepdive.sampler.partitions - 1))}

            # load probabilities to database
            for pid in $DEEPDIVE_FACTORGRAPH_SHARDS; do
                cd \"$DEEPDIVE_RESULTS_DIR\"/$pid
                cat inference_result.out.text |
                pid=$pid deepdive env restore_partitioned_vids |
                DEEPDIVE_LOAD_FORMAT=tsv \\
                deepdive load \(deepdiveInferenceResultVariablesTable | @sh) /dev/stdin
            done
        "
    },


    # Analyzes the inference result variables table and then emits one
    # deepdive create view command per entry of schema.variables_.  Each view is
    # named after the variable's table with an _inference suffix and selects all
    # columns of that table plus the probability, the internal label and the
    # variable id.  The per-variable commands are joined with newlines into the
    # single cmd string.
    "process/model/load_probabilities": {
        dependencies_: ["process/model/load_inference_results"],
        output_: ["data/model/probabilities"],
        style: "cmd_extractor",
        cmd: "
            deepdive db analyze \(deepdiveInferenceResultVariablesTable | @sh)

            # create a view for each app schema variable
            \([ $deepdive.schema.variables_[] | "
                deepdive create view \("\(.variablesTable)_inference" | @sh) as \(
                { SELECT:
                    [ { expr: "\"v\".*" }
                    # probability from the inference results, exposed under the expectation column name
                    , { alias: deepdiveVariableExpectationColumn,   table: "r", column: "prb" }
                    , { alias: deepdiveVariableInternalLabelColumn, table: "i", column: deepdiveVariableInternalLabelColumn }
                    , { alias: deepdiveVariableIdColumn,            table: "i", column: deepdiveVariableIdColumn }
                    ]
                , FROM: [ { alias: "v", table: .variablesTable } ]
                , JOIN:
                    # variable ids
                    # matched to the variable table on every key column
                    [ { LEFT_OUTER: { alias: "i", table: .variablesIdsTable }
                      , ON: { and:  [ .variablesKeyColumns[]
                                    | { eq: [ { table: "i", column: . }
                                            , { table: "v", column: . }
                                            ] }
                                    ] } }
                    # boolean variables contribute one join, all other types contribute two
                    , if .variableType == "boolean" then
                    # inference results
                      { LEFT_OUTER: { alias: "r", table: deepdiveInferenceResultVariablesTable }
                      , ON: { eq:   [ { table: "r", column: "vid" }
                                    , { table: "i", column: deepdiveVariableIdColumn }
                                    ] } }
                    else
                    # category ids are necessary to find the inference result corresponding to the variable
                    # the categories table column is the category column name prefixed with an underscore
                      { LEFT_OUTER: { alias: "c", table: .variablesCategoriesTable }
                      , ON: { and:  [ .variablesCategoryColumns[]
                                    | { eq: [ { table: "c", column: "_\(.)" }
                                            , { table: "v", column: . }
                                            ] }
                                    ] } }
                    # inference results
                    # matched on both the variable id and the category id
                    , { LEFT_OUTER: { alias: "r", table: deepdiveInferenceResultVariablesTable }
                      , ON: { and:  [ { eq: [ { table: "r", column: "vid" }
                                            , { table: "i", column: deepdiveVariableIdColumn }
                                            ] }
                                    , { eq: [ { table: "r", column: "cid" }
                                            , { table: "c", column: "cid" }
                                            ] }
                                    ] } }
                    end
                    ]
                # ordering is requested by probability, descending
                , ORDER_BY:
                    { expr: { table: "r", column: "prb" }
                    , order: "DESC"
                    }
                } | asSql | @sh)"
            ] | join("\n"))
        "
    }

} # end of the process definitions merged into .deepdive_.execution.processes

end # closes the if that guards on process/grounding/combine_factorgraph
